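%This script implements preliminary steps that are shared by all simulation scripts.
%Corresponding author: Rodrigo Adao
%Date: 09/11/2024
%Input: model_data_fgkk.mat, loaded from the folder given by data_path (data_path must already exist in the workspace).
%Also requires trade_target (target cumulative trade-coverage share) to be available, and optionally sigma_alt to override sigma.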

%% Set parameters

%Numerical parameters
tol_parm = 1e-4;    %tolerance level used for checks in the parameter inversion
tol_conv = 1e-7;    %tolerance level used to check price convergence
adj_inner = .4;     %adjustment parameter for the sectoral domestic price in the inner loop
adj_outter = .8;    %adjustment parameter for the sectoral import price in the outer loop (only matters when omega_star ~= 0)
adj_inner_g0 = .1;  %inner-loop adjustment parameter passed to compute_equilibrium_foa_full below
adj_outter_g0 = .8; %outer-loop adjustment parameter passed to compute_equilibrium_foa_full below

%Testing parameters: significance level and the matching two-sided standard-normal critical value
alpha = 0.05;
critical = norminv(1 - alpha/2,0,1);

%Set elasticities in the researcher's model
    omega_star = 0;     %foreign export supply elasticity (FGKK est = -.002)
    sigma = 2.53;       %elasticity of substitution across varieties (FGKK est = 2.53)
    eta = 1.53;         %elasticity of substitution across products (FGKK est = 1.53)
    kappa = 1.19;       %elasticity of substitution across domestic/foreign composites (FGKK est = 1.19)
    sigma_star = 1.04;  %foreign import demand elasticity (FGKK est = 1.04)
    nu =  1;            %elasticity of substitution across sector intermediate goods (FGKK set nu = 1)
    epsilon =  1;       %elasticity of substitution across sector final goods (FGKK set epsilon = 1)
    DX = 0;             %DX = 0 for export tax revenue to accrue to Foreign, and DX = 1 to accrue to Home
    DM = 1;             %DM = 0 for import tax revenue to accrue to Foreign, and DM = 1 to accrue to Home
    gammaM = 1;
    gammaX = 1;
    gammaQ = 1;
    rhoX = 1;
    rhoM = 1;

%Optional override: a caller may define sigma_alt in the workspace before running this script.
%Restrict the lookup to workspace variables so a same-named file/function on the path is ignored.
if exist('sigma_alt', 'var') == 1
    sigma = sigma_alt;
end

%% Import and adjust data

load(data_path + "model_data_fgkk.mat")

%Dimensions: x_val is N x G (G = number of products). The number of sectors S is taken from
%the columns of Lsr -- the same definition used later by [R, S] = size(Lsr) -- so that S is
%defined (and consistent) before it is used to size Dsg.
    [N, G] = size(x_val);
    [~, S] = size(Lsr);

%Vector of dummies linking product to sector: Dsg(s,g) = 1 when the last column of
%hs10_naics for product g equals naics(s)
    Dsg = zeros(S,G);
    for s=1:S
       Dsg(s,:) = ( hs10_naics(:,end) == naics(s) )';
    end
    Dsg = sparse(Dsg);

%Adjust trade values so that, within each sector, product-level flows add up to the IO-table totals
    %Imports: sector factor = IO-table total (PMs_by_Ms) / sample import total of the sector,
    %then every product inherits the factor of its sector through Dsg
    m_val_sample = sparse(m_val);
    m_val_colsum = sum(m_val_sample, 1)';
    adj_m_s = PMs_by_Ms ./ (Dsg*m_val_colsum);
    adj_m_gs = spdiags(Dsg'*adj_m_s, 0, G, G);
    m_val_adj = m_val_sample*adj_m_gs;

    %Exports: same construction, scaled to the sectoral totals in EXs
    x_val_sample = sparse(x_val);
    x_val_colsum = sum(x_val_sample, 1)';
    adj_x_s = EXs ./ (Dsg*x_val_colsum);
    adj_x_gs = spdiags(Dsg'*adj_x_s, 0, G, G);
    x_val_adj = x_val_sample*adj_x_gs;
    
%Define variables: initial (_0) and final (_1) objects used in the rest of the script
    %Trade flows after the IO-table adjustment
    m_ig_0 = m_val_adj;
    x_ig_0 = x_val_adj;

    %Regional matrix and its dimensions (R rows, S sector columns)
    Lsr = sparse( Lsr );
    [R, S] = size(Lsr);

    %Initial and final policy matrices: imports (m_tf_*) and exports (x_tf_*)
    tau_ig_0 = sparse(m_tf_counter_initial);
    tau_ig_1 = sparse(m_tf_counter_final);
    tau_star_ig_0 = sparse(x_tf_counter_initial);
    tau_star_ig_1 = sparse(x_tf_counter_final);

%Drop raw inputs and temporaries that are no longer needed
clearvars m_q1* m_val* x_q1* x_val* adj_x* adj_m*

%% Compute minimum trade levels for the targeted trade share
%For each direction: sort all (i,g) flows from largest to smallest, locate the position where the
%cumulative share of total trade is closest to trade_target, and use the flow there as the cutoff.

%Imports
m_ig_0_sort = sort(m_ig_0(:), 'descend');
m_ig_0_cumsum = cumsum(m_ig_0_sort)/sum(m_ig_0, 'all');
[min_val, min_loc] = min(abs(m_ig_0_cumsum - trade_target));
min_imp = m_ig_0_sort(min_loc);

%Exports
x_ig_0_sort = sort(x_ig_0(:), 'descend');
x_ig_0_cumsum = cumsum(x_ig_0_sort)/sum(x_ig_0, 'all');
[min_val, min_loc] = min(abs(x_ig_0_cumsum - trade_target));
min_exp = x_ig_0_sort(min_loc);

%select sample
    %Sample of shifts: flows strictly above the coverage cutoffs
    ind_Mshifts_mat = m_ig_0 > min_imp;
    ind_Xshifts_mat = x_ig_0 > min_exp;

    %Vectorize row by row (g runs fastest within each i)
    ind_Mshifts_vec = reshape(ind_Mshifts_mat.', N*G, 1);
    ind_Xshifts_vec = reshape(ind_Xshifts_mat.', N*G, 1);

    NMs = sum(ind_Mshifts_vec);
    NXs = sum(ind_Xshifts_vec);
    %Shift-type flag: 0 for import shifts, 1 for export shifts
    exp_sind = [zeros(NMs,1); ones(NXs,1)];

    %Sample for estimation: defined by the same cutoffs as the shift sample
    ind_Msample_mat = ind_Mshifts_mat;
    ind_Xsample_mat = ind_Xshifts_mat;
    ind_Msample_ig = ind_Mshifts_vec;
    ind_Xsample_ig = ind_Xshifts_vec;

    %Sample sizes (displayed on purpose)
    NM = sum(ind_Msample_ig)
    NX = sum(ind_Xsample_ig)

    %Indicators for the three stacked observation blocks: [first import block (NM); second
    %import block (NM); export block (NX)]
    indM = [ones(NM,1);  zeros(NM,1); zeros(NX,1)];
    indT = [zeros(NM,1); ones(NM,1);  zeros(NX,1)];
    indX = [zeros(NM,1); zeros(NM,1); ones(NX,1)];

    %Coverage of the sample (displayed): import value, tau-weighted import value, export value
    tradestat_sample = [sum(m_ig_0(ind_Msample_mat),'all')/sum(m_ig_0,'all'), ...
        sum(tau_ig_0(ind_Msample_mat).*m_ig_0(ind_Msample_mat),'all')/sum(tau_ig_0.*m_ig_0,'all'), ...
        sum(x_ig_0(ind_Xsample_mat),'all')/sum(x_ig_0,'all')]

    %Variables for sample: vectorize each N x G matrix row by row (g fastest), then keep the
    %varieties selected by the import / export sample masks
    m_ig_0_vec = reshape(m_ig_0.', N*G, 1);
    x_ig_0_vec = reshape(x_ig_0.', N*G, 1);
    tau_ig_0_vec = reshape(tau_ig_0.', N*G, 1);
    taustar_ig_0_vec = reshape(tau_star_ig_0.', N*G, 1);

    m_ig_0_sample = m_ig_0_vec(ind_Msample_ig);
    tau_ig_0_sample = tau_ig_0_vec(ind_Msample_ig);
    x_ig_0_sample = x_ig_0_vec(ind_Xsample_ig);
    taustar_ig_0_sample = taustar_ig_0_vec(ind_Xsample_ig);

    %Policy changes from the initial (0) to the final (1) matrices, same vector ordering
    tau_ig_1_vec = reshape(tau_ig_1.', N*G, 1);
    taustar_ig_1_vec = reshape(tau_star_ig_1.', N*G, 1);
    dtau_vec = tau_ig_1_vec - tau_ig_0_vec;
    dtaustar_vec = taustar_ig_1_vec - taustar_ig_0_vec;

%crosswalk from varieties to their good
    %Good code (first column of hs10_naics) for every (i,g) variety, in the same g-fastest order
    %as the *_vec variables: the G codes repeated once per i
    good_ig_vec = repmat(hs10_naics(:,1), N, 1);

    good_Msample_ig = good_ig_vec(ind_Msample_ig);
    good_Xsample_ig = good_ig_vec(ind_Xsample_ig);

    %Distinct goods present in each sample (unique already returns them sorted)
    good_Msample = unique(good_Msample_ig);
    GMsample = length(good_Msample);

    good_Xsample = unique(good_Xsample_ig);
    GXsample = length(good_Xsample);

    %Good-by-observation membership matrices: row g flags the sampled varieties of good g
    D_Msample_g_ig = zeros(GMsample, length(good_Msample_ig));
    for g = 1:GMsample
        D_Msample_g_ig(g,:) = (good_Msample_ig.' == good_Msample(g));
    end

    D_Xsample_g_ig = zeros(GXsample, length(good_Xsample_ig));
    for g = 1:GXsample
        D_Xsample_g_ig(g,:) = (good_Xsample_ig.' == good_Xsample(g));
    end

%Shift-share IV exposure matrices (imports and exports)
%Each estimation-sample variety gets a one-hot row pointing at its own shift; those exposures are
%demeaned within good and stacked into block matrices for the three observation blocks.
    variety_list = (1:1:N*G)';

    %Exports: locate every sampled variety among the shift varieties
    id_Xshifts_vec = variety_list(ind_Xshifts_vec);
    id_Xsample_vec = variety_list(ind_Xsample_ig);
    [sampleX_naive, sampleX_naive_shift_loc] = ismember(id_Xsample_vec, id_Xshifts_vec);

    shareIV_X = zeros(NX,NXs);
    for i=1:NX
        if sampleX_naive(i) == 1
            shareIV_X(i,sampleX_naive_shift_loc(i)) = 1;
        end
    end
    shareIV_X = shareIV_X(sampleX_naive,:);
    [NXnaive, ~] = size(shareIV_X);

    %Restrict the good-membership matrix to the naive subsample and drop empty goods
    D_Xsample_g_ig = D_Xsample_g_ig(:,sampleX_naive');
    a=sum(D_Xsample_g_ig,2);
    D_Xsample_g_ig = D_Xsample_g_ig(a>0,:);

    %Imports: same construction
    id_Mshifts_vec = variety_list(ind_Mshifts_vec);
    id_Msample_vec = variety_list(ind_Msample_ig);
    [sampleM_naive, sampleM_naive_shift_loc] = ismember(id_Msample_vec, id_Mshifts_vec);

    shareIV_M = zeros(NM,NMs);
    for i=1:NM
        if sampleM_naive(i) == 1
            shareIV_M(i,sampleM_naive_shift_loc(i)) = 1;
        end
    end
    shareIV_M = shareIV_M(sampleM_naive,:);
    [NMnaive, ~] = size(shareIV_M);

    D_Msample_g_ig = D_Msample_g_ig(:,sampleM_naive');
    a=sum(D_Msample_g_ig,2);
    D_Msample_g_ig = D_Msample_g_ig(a>0,:);

    %Residualize exposures on good fixed effects: subtract, for every observation, the mean
    %exposure across the observations of the same good
    shareIV_M_resg   = shareIV_M    - D_Msample_g_ig'*( diag(1./sum(D_Msample_g_ig, 2))*D_Msample_g_ig*shareIV_M    );
    shareIV_X_resg   = shareIV_X    - D_Xsample_g_ig'*( diag(1./sum(D_Xsample_g_ig, 2))*D_Xsample_g_ig*shareIV_X    );

    %Stack: the two import observation blocks share the import exposures; exports use their own
    shareIV_tau = [shareIV_M, zeros(NMnaive,NXs);
                   shareIV_M, zeros(NMnaive,NXs);
                   zeros(NXnaive,NMs), shareIV_X ];

    %Mask of the stacked full sample (2*NM+NX rows) that belongs to the naive subsample
    sample_naive = [sampleM_naive; sampleM_naive; sampleX_naive];

    share_MCg_nGE = [shareIV_M_resg, zeros(NMnaive,NXs);
                     shareIV_M_resg, zeros(NMnaive,NXs);
                     zeros(NXnaive,NMs), shareIV_X_resg ];

    %Scale by the mean (across rows) of the row sums of squared exposures
    share_MCg_nGE = share_MCg_nGE/mean( sum( share_MCg_nGE.^2 , 2) );

    %First import block rows, import-shift columns. indM is defined on the full stacked sample
    %(2*NM+NX rows) while share_MCg_nGE only has the naive-subsample rows, so indM must be
    %restricted to the naive subsample before it is used as a row mask.
    shareM_MCg_nGE = share_MCg_nGE(indM(sample_naive) == 1, exp_sind==0);

    clearvars shareIV_M shareIV_X shareIV_M_resg shareIV_X_resg m_ig_0_sort x_ig_0_sort m_ig_0_cumsum x_ig_0_cumsum m_ig_0_vec x_ig_0_vec ...
              good_ig_vec D_Msample_s_ig D_Xsample_s_ig sec_ig_vec variety_list

    %% Compute matrix of shares in the researcher's model given initial conditions
    
%Invert parameters
    %Initial import values scaled by (1+tau_ig_0), summed over i for each product g, then
    %aggregated to sectors through the product-to-sector dummies Dsg
    m_g_0 = sum( (1+tau_ig_0).*m_ig_0, 1)'; 
    m_s_0 = Dsg*m_g_0;

    %Invert the model at the initial data (invert_param_full is defined elsewhere) to obtain the
    %initial-period parameters/levels returned below. NOTE(review): the long positional argument
    %list must match invert_param_full's signature exactly -- confirm against its definition.
    [delta_ig_0, a_star_ig_0, z_star_ig_0, a_ig_0, aM_g_0, AM_s_0, AD_s_0, betaT_0, beta_s_0, alphaL_s_0, alphaI_s_0, alpha_ks_0, barL_rs_0, Z_rs_0, D_0, E_s_0, F_0] = ... 
        invert_param_full(x_ig_0, m_ig_0, tau_star_ig_0, tau_ig_0, trade_con_share, tot_labor_comp, tot_intermediate, sales_s, IO_sales, Lsr, ...
        Dsg, DX, DM, omega_star, sigma, eta, kappa, sigma_star, gammaM, gammaX,  tol_parm);

    %Initial sectoral prices set to one
    p_s_0 = ones(S,1);
    
%compute shares
    %Solve the model at the initial parameters (compute_equilibrium_foa_full is defined elsewhere)
    %using the inner/outer adjustment parameters adj_inner_g0/adj_outter_g0. Outputs are named
    %rho_<outcome>_ig_<shock>_ig; the shift/sample masks passed last presumably select which
    %rows/columns are returned -- confirm against the function definition.
    [p_s_IV, PM_s_IV, ...
             rho_qm_ig_tau_ig, rho_qm_ig_taustar_ig, rho_qm_ig_a_ig, rho_qm_ig_zstar_ig, rho_qm_ig_astar_ig, rho_qm_ig_delta_ig, ...
             rho_pm_ig_tau_ig, rho_pm_ig_taustar_ig, rho_pm_ig_a_ig, rho_pm_ig_zstar_ig, rho_pm_ig_astar_ig, rho_pm_ig_delta_ig, ...
             rho_px_ig_tau_ig, rho_px_ig_taustar_ig, rho_px_ig_a_ig, rho_px_ig_zstar_ig, rho_px_ig_astar_ig, rho_px_ig_delta_ig, ...
             rho_pstar_ig_tau_ig, rho_pstar_ig_taustar_ig, rho_pstar_ig_a_ig, rho_pstar_ig_zstar_ig, rho_pstar_ig_astar_ig, rho_pstar_ig_delta_ig, ...
             rho_rm_ig_tau_ig, rho_rm_ig_taustar_ig, rho_rm_ig_a_ig, rho_rm_ig_zstar_ig, rho_rm_ig_astar_ig, rho_rm_ig_delta_ig] ...
             = compute_equilibrium_foa_full( delta_ig_0, a_star_ig_0, z_star_ig_0, a_ig_0, aM_g_0, AM_s_0, AD_s_0, betaT_0, beta_s_0, alphaL_s_0, alphaI_s_0, alpha_ks_0, barL_rs_0, Z_rs_0, D_0, ... 
             tau_ig_0, tau_star_ig_0, m_ig_0, E_s_0, p_s_0, Dsg, DX, DM, omega_star, sigma, eta, kappa, sigma_star, nu, epsilon, gammaQ, gammaX, gammaM, rhoX, rhoM, tol_conv, adj_inner_g0, adj_outter_g0, ... 
             ind_Mshifts_vec, ind_Xshifts_vec, ind_Msample_ig, ind_Xsample_ig );  
         
    %Model share matrix: for each outcome (pm, rm, px), the tau-shift columns next to the
    %tau_star-shift columns; outcomes stacked vertically in the order [pm; rm; px]
    share_pm    = [rho_pm_ig_tau_ig    , rho_pm_ig_taustar_ig   ];
    share_px    = [rho_px_ig_tau_ig    , rho_px_ig_taustar_ig   ];
    share_rm    = [rho_rm_ig_tau_ig    , rho_rm_ig_taustar_ig   ];
    shareMOD_dW    = sparse( [share_pm; share_rm; share_px] );

%Free the intermediate outputs, inverted parameters, and policy vectors no longer needed
clearvars rho_* share_qm share_pm share_pstar share_px share_rm ...
             delta_ig_0 a_star_ig_0 z_star_ig_0 a_ig_0 aM_g_0 AM_s_0 AD_s_0 betaT_0 beta_s_0 alphaL_s_0 alphaI_s_0 alpha_ks_0 barL_rs_0 Z_rs_0 ...
             tau_ig_1 tau_star_ig_1 tau_ig_0_vec taustar_ig_0_vec tau_ig_1_vec taustar_ig_1_vec

%% Define weight vector and control matrices

%welfare weights -- omega in notation
    %Export weight: sampled export value as a percent of F_0
    ww_px = 100*(x_ig_0_sample/F_0);
    %Import weight: (1+tau)-scaled sampled import value as a percent of F_0
    ww_vm = 100*(1+tau_ig_0_sample).*m_ig_0_sample/F_0;

    %Stacked weights, one block per observation block: [-imports; +imports; exports]
    ww_dW = [- ww_vm; ww_vm; ww_px];

    %Sign-only alternative: -1/NM and +1/NM on the import blocks, 1/NX on the export block
    ww_dWsign = [- (1/NM)*ones(NM,1); (1/NM)*ones(NM,1); (1/NX)*ones(NX,1)];

    %Diagonal weight matrices
    ww_dWmat = diag(ww_dW);
    ww_dWsign_mat = diag(ww_dWsign);

%Control vectors
    %Nobs stacked observations, Nsh shifts
    [Nobs, Nsh] = size(shareMOD_dW);
    %Total model share of each observation across all shifts
    shareMOD_dWbar = sum(shareMOD_dW,2);
    %Ca: constant and block dummies; C additionally controls for the total model share
    Ca = [ones(Nobs,1), indX, indM];
    C = [ones(Nobs,1), shareMOD_dWbar, indX, indM];

%Compute matrices for OLS coefs: (X'X)^(-1) X', so that coefficients = M_term1*y
    MCa_term1 = (Ca'*Ca)\Ca';
    MC_term1 = (C'*C)\C';

%Naive subsample: same controls restricted to the rows flagged by sample_naive
    Cn = C(sample_naive,:);
    MCn_term1 = (Cn'*Cn)\Cn';

    %The full-sample weight matrix ww_dWmat = diag(ww_dW) was already built when the welfare
    %weights were defined; only the naive-subsample version is new here.
    ww_dWmat_naive = diag(ww_dW(sample_naive));
